# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run BERT on SQuAD 1.1 and SQuAD 2.0 in TF 2.x."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os
import time

from absl import app
from absl import flags
from absl import logging
import gin
import tensorflow as tf

from official.nlp.bert import configs as bert_configs
from official.nlp.bert import run_squad_helper
from official.nlp.bert import tokenization
from official.nlp.data import squad_lib as squad_lib_wp
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils


flags.DEFINE_string('vocab_file', None,
                    'The vocabulary file that the BERT model was trained on.')

# More flags can be found in run_squad_helper.
run_squad_helper.define_common_squad_flags()

FLAGS = flags.FLAGS


def train_squad(strategy,
                input_meta_data,
                custom_callbacks=None,
                run_eagerly=False,
                init_checkpoint=None):
  """Run bert squad training."""
  bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
  init_checkpoint = init_checkpoint or FLAGS.init_checkpoint
  run_squad_helper.train_squad(strategy, input_meta_data, bert_config,
                               custom_callbacks, run_eagerly, init_checkpoint)


def predict_squad(strategy, input_meta_data):
  """Makes predictions for the squad dataset."""
  bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
  run_squad_helper.predict_squad(
      strategy, input_meta_data, tokenizer, bert_config, squad_lib_wp)


def eval_squad(strategy, input_meta_data):
  """Evaluate on the squad dataset."""
  bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
  eval_metrics = run_squad_helper.eval_squad(
      strategy, input_meta_data, tokenizer, bert_config, squad_lib_wp)
  return eval_metrics


def export_squad(model_export_path, input_meta_data):
  """Exports a trained model as a `SavedModel` for inference.

  Args:
    model_export_path: a string specifying the path to the SavedModel directory.
    input_meta_data: dictionary containing meta data about input and model.

  Raises:
    Export path is not specified, got an empty string or None.
  """
  bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
  run_squad_helper.export_squad(model_export_path, input_meta_data, bert_config)


def main(_):
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)

  with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
    input_meta_data = json.loads(reader.read().decode('utf-8'))

  if FLAGS.mode == 'export_only':
    export_squad(FLAGS.model_export_path, input_meta_data)
    return

  # Configures cluster spec for multi-worker distribution strategy.
  if FLAGS.num_gpus > 0:
    _ = distribution_utils.configure_cluster(FLAGS.worker_hosts,
                                             FLAGS.task_index)
  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      all_reduce_alg=FLAGS.all_reduce_alg,
      tpu_address=FLAGS.tpu)

  if 'train' in FLAGS.mode:
    if FLAGS.log_steps:
      custom_callbacks = [keras_utils.TimeHistory(
          batch_size=FLAGS.train_batch_size,
          log_steps=FLAGS.log_steps,
          logdir=FLAGS.model_dir,
      )]
    else:
      custom_callbacks = None

    train_squad(
        strategy,
        input_meta_data,
        custom_callbacks=custom_callbacks,
        run_eagerly=FLAGS.run_eagerly,
    )
  if 'predict' in FLAGS.mode:
    predict_squad(strategy, input_meta_data)
  if 'eval' in FLAGS.mode:
    eval_metrics = eval_squad(strategy, input_meta_data)
    f1_score = eval_metrics['final_f1']
    logging.info('SQuAD eval F1-score: %f', f1_score)
    summary_dir = os.path.join(FLAGS.model_dir, 'summaries', 'eval')
    summary_writer = tf.summary.create_file_writer(summary_dir)
    with summary_writer.as_default():
      # TODO(lehou): write to the correct step number.
      tf.summary.scalar('F1-score', f1_score, step=0)
      summary_writer.flush()
    # Also write eval_metrics to json file.
    squad_lib_wp.write_to_json_files(
        eval_metrics, os.path.join(summary_dir, 'eval_metrics.json'))
    time.sleep(60)


if __name__ == '__main__':
  flags.mark_flag_as_required('bert_config_file')
  flags.mark_flag_as_required('model_dir')
  app.run(main)
